from PyQt5 import QtWidgets, uic, QtCore, QtGui
import pyqtgraph as pg
import sys
import numpy as np
import os
import Initialization
import Initialization_test
import Movements
import Close_connection
import Spectral_calibration
import Find_scan_range
import Data_acquisition
import Get_calibrated_position_axis
import Get_spectrum_dft
from datetime import datetime


class MainWindow(QtWidgets.QMainWindow):
    """Main window of the scanning acquisition GUI.

    Loads the layout from ``Interface.ui``, reads the scale and spectral
    calibration parameter files from the working directory, moves the stage
    through the ``Movements`` module, acquires one reading per position via
    ``Data_acquisition`` and plots the raw signal versus position together
    with the spectrum computed by ``Get_spectrum_dft``.
    """

    # Controller status codes the acquisition loop treats as "motion
    # finished". Their meaning is defined by the Movements/controller API.
    # NOTE(review): confirm 0 and 3 against the controller documentation.
    MOTION_DONE_STATUSES = (0, 3)

    # Maximum number of most-recent points drawn in the live plot during a
    # scan (must be positive); None draws everything acquired so far.
    LIVE_PLOT_MAX_POINTS = None

    # Factor converting positions entered in mm into the integer units passed
    # to Movements (x1e6, presumably nm -- confirm with the controller API).
    MM_TO_DEVICE_UNITS = 1000000

    def __init__(self, *args, **kwargs):
        super(MainWindow, self).__init__(*args, **kwargs)
        pg.setConfigOption('background', 'w')
        # Load the UI page; this creates every widget referenced below.
        uic.loadUi('Interface.ui', self)

        # Calibration data read from files found in the working directory.
        self.scale = np.genfromtxt(self._find_parameter_file("parameters_scale.txt"), delimiter='\t')
        ref_wave = np.genfromtxt(self._find_parameter_file("parameters_cal.txt"), delimiter='\t')
        ref1_wave = ref_wave[0, :]
        # First/last entries of the first row bound the selectable wavelength range.
        self.low_wave = int(ref1_wave[0])
        self.high_wave = int(ref1_wave[-1])

        # Results of the last completed scan (empty until a scan finishes).
        self.position_axis = np.array([])
        self.signal = np.array([])

        # Controller handles set by connect_gemini(); None means "not connected".
        self.system_index = None
        self.channel_index = None
        self.stop_flag = False

        self.resolution = float(self.resolution_txt.text())
        self.apodization_width = float(self.apodization_txt.text())
        self.sampling_factor = float(self.sampling_txt.text())
        self.spectral_points = float(self.spectral_point_txt.text())

        self.graphWidget.plotItem.showGrid(True, True, 0.7)
        self.graphMat.plotItem.showGrid(True, True, 0.7)
        self.graphMat_2.plotItem.showGrid(True, True, 0.7)
        self.stop_btn.clicked.connect(self.stop_scan)
        self.start_btn.clicked.connect(self.scan)
        self.connect_btn.clicked.connect(self.connect_gemini)
        self.go_home.clicked.connect(self.move_home)
        self.move_abs.clicked.connect(self.move_absolute)
        self.value_abs.returnPressed.connect(self.move_absolute)
        self.move_left.clicked.connect(self.move_lft)
        self.move_right.clicked.connect(self.move_rgt)
        self.exit_btn.clicked.connect(self.disconnect_controller)
        self.save_btn.clicked.connect(self.saving)

        self.save_btn.setIcon(QtGui.QIcon("save.ico"))
        self.save_btn.setIconSize(QtCore.QSize(20, 40))

        self.progress.setValue(0)

        self.start_wave_txt.returnPressed.connect(self.scan_range_plot)
        self.end_wave_txt.returnPressed.connect(self.scan_range_plot)
        self.apodization_txt.returnPressed.connect(self.scan_range_plot)
        self.resolution_txt.returnPressed.connect(self.scan_range)

        self.graphWidget.setLabel('left', "Amplitude")
        self.graphWidget.setLabel('bottom', "Position[mm]")

        self.graphMat.setLabel('left', "Amplitude")
        self.graphMat.setLabel('bottom', "Wavelength [nm]")

        self.graphMat_2.setLabel('left', "Amplitude")
        self.graphMat_2.setLabel('bottom', "Position[mm]")

        self.start_btn.setIcon(QtGui.QIcon("play.png"))
        self.start_btn.setIconSize(QtCore.QSize(40, 60))
        self.stop_btn.setIcon(QtGui.QIcon("stop.png"))
        self.stop_btn.setIconSize(QtCore.QSize(20, 20))

        self.move_right.setIcon(QtGui.QIcon("right.ico"))
        self.move_right.setIconSize(QtCore.QSize(20, 40))
        self.move_left.setIcon(QtGui.QIcon("left.ico"))
        self.move_left.setIconSize(QtCore.QSize(20, 40))
        self.move_right.setLayoutDirection(QtCore.Qt.RightToLeft)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _find_parameter_file(suffix):
        """Return the name of a file in the working directory ending with *suffix*.

        If several files match, the last one in sorted order is used so the
        choice is deterministic.

        Raises:
            FileNotFoundError: if no file in the working directory matches.
        """
        matches = sorted(name for name in os.listdir(".") if name.endswith(suffix))
        if not matches:
            raise FileNotFoundError(
                "No file ending with '%s' found in %s" % (suffix, os.getcwd()))
        return matches[-1]

    def _warn(self, message):
        """Show a modal warning dialog with *message*."""
        QtWidgets.QMessageBox.warning(self, "Warning", message)

    def _require_connection(self):
        """Return True if the controller is connected, otherwise warn and return False.

        Unhandled exceptions inside Qt slots abort a PyQt5 application, so
        motion commands are refused up front instead of failing later.
        """
        if self.system_index is None:
            self._warn("Controller not connected: press Connect first.")
            return False
        return True

    def _move_to(self, position_mm):
        """Command an absolute move to *position_mm* (mm, converted to device units)."""
        Movements.move_absolute(self.system_index, self.channel_index,
                                int(position_mm * self.MM_TO_DEVICE_UNITS))

    def _relative_step_units(self):
        """Return the relative-move distance from the UI in integer device units."""
        return int(float(self.value_rel.text()) * self.MM_TO_DEVICE_UNITS)

    def _clamp_wavelengths(self):
        """Read the wavelength range from the UI, clamped to the calibrated range.

        Values outside [low_wave, high_wave] are replaced by the limit and the
        corresponding text field is updated so the UI shows what is used.

        Returns:
            (start_wave, end_wave)
        """
        start_wave = float(self.start_wave_txt.text())
        end_wave = float(self.end_wave_txt.text())
        if start_wave < self.low_wave:
            start_wave = self.low_wave
            self.start_wave_txt.setText(str(self.low_wave))
        if end_wave > self.high_wave:
            end_wave = self.high_wave
            self.end_wave_txt.setText(str(self.high_wave))
        return start_wave, end_wave

    def _scan_geometry(self, start_wave, end_wave, resolution, sampling_factor):
        """Compute the scan start/end positions and step count, and show them in the UI.

        The scan range comes from Find_scan_range for the mean wavelength and
        the requested resolution. The step count is the range divided by
        ``freq_mm / sampling_factor``, where ``freq_mm`` is the reciprocal of
        the wave-to-frequency calibration polynomial evaluated at
        ``1 / start_wave``.

        Returns:
            (start_position, end_position, number_steps); number_steps is an int.
        """
        _, P_wave2freq = Spectral_calibration.spectral_calibration()
        lambda_mean = (start_wave + end_wave) / 2
        start_position, end_position = Find_scan_range.scan_range(lambda_mean, resolution)

        freq = np.polyval(P_wave2freq, 1 / start_wave)
        freq_mm = 1 / freq
        number_steps = int(np.ceil(abs(start_position - end_position) / (freq_mm / sampling_factor)))

        self.start_position_txt.setText(str(round(start_position, 3)))
        self.end_position_txt.setText(str(round(end_position, 3)))
        self.number_steps_txt.setText(str(number_steps))
        return start_position, end_position, number_steps

    def _compute_and_plot_spectrum(self, start_wave, end_wave, spectral_points, apodization_width):
        """Calibrate the stored position axis, compute the spectrum and plot both results."""
        calibrated_positions = Get_calibrated_position_axis.get_calibrated_position_axis(
            self.position_axis, self.scale)
        spectrum, wave, _freq, interf_apo = Get_spectrum_dft.get_spectrum(
            self.signal, calibrated_positions, start_wave, end_wave,
            spectral_points, apodization_width)
        pen = pg.mkPen(color=(255, 0, 0))
        self.graphMat.plot(wave, spectrum, clear=True, pen=pen)
        self.graphMat_2.plot(calibrated_positions, interf_apo, clear=True, pen=pen)
        self.tabWidget.setCurrentIndex(1)

    def _wait_for_motion(self):
        """Poll the controller until it reports one of MOTION_DONE_STATUSES.

        Events are processed while waiting so the GUI stays responsive and
        Stop can be pressed.

        Returns:
            True once the motion has finished, False if the user pressed Stop.
        """
        while True:
            status = Movements.get_status(self.system_index, self.channel_index)
            if status.value in self.MOTION_DONE_STATUSES:
                return True
            QtWidgets.QApplication.processEvents()
            if self.stop_flag:
                return False

    def _acquire(self, start_position, step, number_steps):
        """Step the stage across the range and take one reading per position.

        Returns:
            (positions, signal) arrays, or None if the user pressed Stop.
        """
        self._move_to(start_position)
        # Wait for the stage to reach the start before taking the first sample.
        if not self._wait_for_motion():
            return None
        signal = np.ravel(Data_acquisition.data_acquisition())
        positions = np.array([start_position])

        for i in range(number_steps):
            if self.stop_flag:
                return None
            # Computed from the start (not accumulated) to avoid rounding drift.
            next_position = start_position + (i + 1) * step
            self._move_to(next_position)
            if not self._wait_for_motion():
                return None
            # QProgressBar.setValue takes an int.
            self.progress.setValue(int(round(100 * (i + 1) / number_steps)))
            # N.B.: substitute with your data acquisition system; this is for NI's DAQ card.
            signal = np.append(signal, Data_acquisition.data_acquisition())
            positions = np.append(positions, next_position)

            window = self.LIVE_PLOT_MAX_POINTS
            if window is None:
                self.update_plot(positions, signal)
            else:
                self.update_plot(positions[-window:], signal[-window:])
            QtWidgets.QApplication.processEvents()

        return positions, signal

    # ------------------------------------------------------------------
    # Slots
    # ------------------------------------------------------------------

    def scan_range(self):
        """Update the displayed scan positions and step count from the current settings."""
        start_wave, end_wave = self._clamp_wavelengths()
        resolution = float(self.resolution_txt.text())
        sampling_factor = float(self.sampling_txt.text())
        self._scan_geometry(start_wave, end_wave, resolution, sampling_factor)

    def scan_range_plot(self):
        """Update the scan geometry and, if a scan exists, recompute and plot its spectrum."""
        start_wave, end_wave = self._clamp_wavelengths()
        resolution = float(self.resolution_txt.text())
        sampling_factor = float(self.sampling_txt.text())
        apodization_width = float(self.apodization_txt.text())
        spectral_points = float(self.spectral_point_txt.text())
        self._scan_geometry(start_wave, end_wave, resolution, sampling_factor)

        if self.position_axis.size > 0:
            self._compute_and_plot_spectrum(start_wave, end_wave, spectral_points, apodization_width)

    def connect_gemini(self):
        """Initialise the controller and store its system/channel handles."""
        self.system_index, self.channel_index = Initialization_test.initialization()

    def move_home(self):
        """Move the stage to absolute position 0."""
        if self._require_connection():
            self._move_to(0)

    def move_absolute(self):
        """Move the stage to the absolute position (mm) entered in the UI."""
        if self._require_connection():
            self._move_to(float(self.value_abs.text()))

    def move_lft(self):
        """Move the stage by the relative distance from the UI, in the negative direction."""
        if self._require_connection():
            Movements.move_relative(self.system_index, self.channel_index, -self._relative_step_units())

    def move_rgt(self):
        """Move the stage by the relative distance from the UI, in the positive direction."""
        if self._require_connection():
            Movements.move_relative(self.system_index, self.channel_index, self._relative_step_units())

    def disconnect_controller(self):
        """Close the controller connection and forget its handles."""
        if self._require_connection():
            Close_connection.close_connection(self.system_index)
            self.system_index = None
            self.channel_index = None

    def saving(self):
        """Save the last scan as Measurements/<YYYY-MM-DD>/<name>.txt.

        The file holds two rows: the position axis and the signal.
        """
        if self.position_axis.size == 0:
            self._warn("No measurement to save: run a scan first.")
            return
        file_name = self.saving_name.text()
        if not file_name.strip():
            self._warn("Please enter a file name before saving.")
            return
        # Evaluate the date once so folder and file cannot straddle midnight.
        folder = os.path.join("Measurements", datetime.today().strftime('%Y-%m-%d'))
        os.makedirs(folder, exist_ok=True)
        np.savetxt(os.path.join(folder, file_name + ".txt"),
                   np.stack((self.position_axis, self.signal)))

    def scan(self):
        """Run a full scan, then calibrate the position axis and plot the spectrum.

        If the user presses Stop, the scan is abandoned and the previous
        results are kept.
        """
        if not self._require_connection():
            return
        self.tabWidget.setCurrentIndex(0)
        self.stop_flag = False
        self.graphWidget.clear()
        self.progress.setValue(0)
        start_wave, end_wave = self._clamp_wavelengths()

        self.resolution = float(self.resolution_txt.text())
        self.apodization_width = float(self.apodization_txt.text())
        self.sampling_factor = float(self.sampling_txt.text())
        self.spectral_points = float(self.spectral_point_txt.text())

        start_position, end_position, number_steps = self._scan_geometry(
            start_wave, end_wave, self.resolution, self.sampling_factor)
        # A zero-length range yields no steps; avoid the 0/0 division.
        step = (end_position - start_position) / number_steps if number_steps else 0.0

        # processEvents() runs during the scan, so block re-entrant Start clicks.
        self.start_btn.setEnabled(False)
        try:
            acquired = self._acquire(start_position, step, number_steps)
        finally:
            self.start_btn.setEnabled(True)
        if acquired is None:
            return

        positions, signal = acquired
        self.signal = np.ravel(signal)
        self.position_axis = np.array(positions)

        # Calibration of the position axis and Fourier transform.
        self._compute_and_plot_spectrum(start_wave, end_wave, self.spectral_points, self.apodization_width)

    def stop_scan(self):
        """Request that a running scan stop at the next check."""
        self.stop_flag = True

    def update_plot(self, X, Y):
        """Replace the live plot with Y versus X."""
        pen = pg.mkPen(color=(255, 0, 0))
        self.graphWidget.plot(X, Y, clear=True, pen=pen)


def main():
    """Create the Qt application, show the main window and run the event loop.

    The process exits with the event loop's return code.
    """
    application = QtWidgets.QApplication(sys.argv)
    window = MainWindow()
    window.show()
    exit_code = application.exec_()
    sys.exit(exit_code)


if __name__ == '__main__':
    # Launch the GUI only when executed as a script, not when imported.
    main()